/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.fa;

import java.util.Arrays;
import java.util.List;
import java.util.Random;

import mosdi.util.SequenceUtils;

/** Markovian text model of fixed order k (k = {@code order} >= 1): the distribution
 *  of the next character depends on the k previously generated characters.
 *  Every state represents one k-mer (context); there are alphabetSize^k states,
 *  numbered such that the state index is the base-alphabetSize number of the context
 *  with its first character as most significant digit (see {@link #getContext(int)}).
 *  Transition probabilities are estimated from counts of (k+1)-grams. A reverse model
 *  (built from the reversed (k+1)-gram counts) is created alongside.
 */
public class MarkovianTextModel extends FiniteMemoryTextModel {
	private final int order;
	private final int alphabetSize;
	private int stateCount;
	// For each state, the target state under character 0; character c leads to
	// state transitions[state]+c.
	private int[] transitions;
	// transitionProbs[state][c]: probability of generating character c in the given state.
	private double[][] transitionProbs;
	// each state's equilibrium probability
	private double[] eqDistribution;
	// lazily built prefix sums of eqDistribution (see findState)
	private double[] cumulativeEqDistribution;

	private MarkovianTextModel reverseModel;

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabetSize Number of characters in the alphabet.
	 * @param text Sample text (as character indices) from which the empiric distribution is to be estimated.
	 * @throws IllegalArgumentException if order<1, if some context never occurs, or if some
	 *         transition has probability zero.
	 */
	public MarkovianTextModel(int order, int alphabetSize, int[] text) {
		this(order, alphabetSize, text, 0.0d);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabetSize Number of characters in the alphabet.
	 * @param text Sample text (as character indices) from which the empiric distribution is to be estimated.
	 * @param pseudoCounts Value every (order+1)-gram count starts with before counting.
	 * @throws IllegalArgumentException if order<1, if some context never occurs, or if some
	 *         transition has probability zero.
	 */
	public MarkovianTextModel(int order, int alphabetSize, int[] text, double pseudoCounts) {
		checkOrder(order);
		this.order = order;
		this.alphabetSize = alphabetSize;
		double[] wordFrequencies = newWordFrequencyTable(pseudoCounts);
		SequenceUtils.countQGrams(order+1, alphabetSize, text, wordFrequencies);
		initialize(wordFrequencies);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabetSize Number of characters in the alphabet.
	 * @param qGramCounts Counts of q-grams for q=order+1 as returned by SequenceUtils.countQGrams().
	 * @throws IllegalArgumentException if order<1, if the table length is not
	 *         alphabetSize^(order+1), if some context has no counts, or if some transition
	 *         has probability zero.
	 */
	public MarkovianTextModel(int order, int alphabetSize, double[] qGramCounts) {
		checkOrder(order);
		if (qGramCounts.length!=(int)Math.pow(alphabetSize, order+1)) throw new IllegalArgumentException("Length of q-gram table must be alphabetSize^(order+1).");
		this.order = order;
		this.alphabetSize = alphabetSize;
		initialize(qGramCounts);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabet Alphabet used to translate the text into character indices.
	 * @param text Sample text from which the empiric distribution is to be estimated.
	 */
	public MarkovianTextModel(int order, Alphabet alphabet, String text) {
		this(order,alphabet.size(),alphabet.buildIndexArray(text),0.0d);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabet Alphabet used to translate the text into character indices.
	 * @param text Sample text from which the empiric distribution is to be estimated.
	 * @param pseudoCounts Value every (order+1)-gram count starts with before counting.
	 */
	public MarkovianTextModel(int order, Alphabet alphabet, String text, double pseudoCounts) {
		this(order,alphabet.size(),alphabet.buildIndexArray(text),pseudoCounts);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabetSize Number of characters in the alphabet.
	 * @param texts Sample texts (as character indices) from which the empiric distribution is to be estimated.
	 */
	public MarkovianTextModel(int order, int alphabetSize, List<int[]> texts) {
		this(order, alphabetSize, texts, 0.0d);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabetSize Number of characters in the alphabet.
	 * @param texts Sample texts (as character indices) from which the empiric distribution is to be estimated.
	 * @param pseudoCounts Value every (order+1)-gram count starts with before counting.
	 */
	public MarkovianTextModel(int order, int alphabetSize, List<int[]> texts, double pseudoCounts) {
		checkOrder(order);
		this.order = order;
		this.alphabetSize = alphabetSize;
		double[] wordFrequencies = newWordFrequencyTable(pseudoCounts);
		for (int[] text : texts) {
			SequenceUtils.countQGrams(order+1, alphabetSize, text, wordFrequencies);
		}
		initialize(wordFrequencies);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabet Alphabet used to translate the texts into character indices.
	 * @param texts Sample texts from which the empiric distribution is to be estimated.
	 */
	public MarkovianTextModel(int order, Alphabet alphabet, List<String> texts) {
		this(order, alphabet, texts, 0.0d);
	}

	/** Constructor.
	 *
	 * @param order Order of the model (i.e. length of relevant history), must be >=1.
	 * @param alphabet Alphabet used to translate the texts into character indices.
	 * @param texts Sample texts from which the empiric distribution is to be estimated.
	 * @param pseudoCounts Value every (order+1)-gram count starts with before counting.
	 */
	public MarkovianTextModel(int order, Alphabet alphabet, List<String> texts, double pseudoCounts) {
		checkOrder(order);
		this.order = order;
		this.alphabetSize = alphabet.size();
		double[] wordFrequencies = newWordFrequencyTable(pseudoCounts);
		for (String text : texts) {
			SequenceUtils.countQGrams(order+1, alphabetSize, alphabet.buildIndexArray(text), wordFrequencies);
		}
		initialize(wordFrequencies);
	}

	/** Constructor used to build the reverse model of an already constructed model.
	 *
	 * @param order Order of the model (i.e. length of relevant history)
	 * @param wordFrequencies Table of frequencies (in arbitrary units) of (order+1)-grams
	 *                        in lexicographic order.
	 * @param reverseModel The model this one is the reverse of.
	 */
	private MarkovianTextModel(int order, int alphabetSize, double[] wordFrequencies, MarkovianTextModel reverseModel) {
		checkOrder(order);
		if (wordFrequencies.length!=(int)Math.pow(alphabetSize,(order+1))) throw new IllegalArgumentException("Array wordFrequencies has invalid length.");
		this.order = order;
		this.alphabetSize = alphabetSize;
		buildTables(wordFrequencies);
		this.reverseModel = reverseModel;
	}

	/** Throws an IllegalArgumentException if order<1. */
	private static void checkOrder(int order) {
		if (order<1) throw new IllegalArgumentException("Order must be >=1.");
	}

	/** Allocates a table with one entry per (order+1)-gram, each set to pseudoCounts. */
	private double[] newWordFrequencyTable(double pseudoCounts) {
		double[] wordFrequencies = new double[(int)Math.pow(alphabetSize, order+1)];
		Arrays.fill(wordFrequencies, pseudoCounts);
		return wordFrequencies;
	}

	/** Builds this model's tables and creates the reverse model from the reversed
	 *  (order+1)-gram frequencies. */
	private void initialize(double[] wordFrequencies) {
		buildTables(wordFrequencies);
		reverseModel = new MarkovianTextModel(order,alphabetSize,revertWordFrequencies(wordFrequencies),this);
	}

	@Override
	public FiniteMemoryTextModel reverseTextModel() {
		return reverseModel;
	}

	/** Reverses the order+1 base-alphabetSize digits of an (order+1)-gram index,
	 *  i.e. returns the index of the reversed word. */
	private int revertWordIndex(int index) {
		int revIndex = 0;
		for (int i=0; i<=order; ++i) {
			revIndex*=alphabetSize;
			revIndex+=index % alphabetSize;
			index/=alphabetSize;
		}
		return revIndex;
	}

	/** Returns a new table in which the frequency of every word is stored at the
	 *  index of its reversed word. */
	private double[] revertWordFrequencies(double[] wordFrequencies) {
		double[] result = new double[wordFrequencies.length];
		for (int i=0; i<wordFrequencies.length; ++i) {
			result[revertWordIndex(i)]=wordFrequencies[i];
		}
		return result;
	}

	/** Derives transition targets, transition probabilities and the equilibrium
	 *  distribution from (order+1)-gram frequencies given in lexicographic order.
	 *
	 * @throws IllegalArgumentException if a context has no positive total frequency or
	 *         if some transition probability is zero.
	 */
	private void buildTables(double[] wordFrequencies) {
		stateCount = (int)Math.pow(alphabetSize,order);
		transitions = new int[stateCount];
		transitionProbs = new double[stateCount][];
		eqDistribution = new double[stateCount];
		// Dropping the first character of a context is "state % h"; appending a
		// character c afterwards gives (state % h)*alphabetSize + c.
		int h = (int)Math.pow(alphabetSize, order-1);
		int wordIndex = 0;
		double totalSum = 0.0;
		for (int state=0; state<stateCount; ++state) {
			transitions[state] = (state%h)*alphabetSize;
			double sum = 0.0;
			for (int c=0; c<alphabetSize; ++c) sum+=wordFrequencies[wordIndex+c];
			transitionProbs[state] = new double[alphabetSize];
			for (int c=0; c<alphabetSize; ++c) transitionProbs[state][c]=wordFrequencies[wordIndex++]/sum;
			// until normalized below, eqDistribution holds the context's total frequency
			eqDistribution[state]=sum;
			totalSum+=sum;
		}

		// Uses the private tables directly (not the overridable getProbability), since
		// this runs during construction.
		for (int state=0; state<stateCount; ++state) {
			// A context that never occurs has sum 0, giving 0/0 = NaN probabilities that
			// the "== 0.0" test below cannot detect (NaN compares unequal to everything).
			if (!(eqDistribution[state] > 0.0)) {
				throw new IllegalArgumentException("The Markovian Modell has no observations for state "+state+" (total frequency "+eqDistribution[state]+")");
			}
			for (int c=0; c<alphabetSize; ++c) {
				if (transitionProbs[state][c] == 0.0d) {
					throw new IllegalArgumentException("The Markovian Modell has a transition (State, character) = ("+state+","+c+") with a value equals 0");
				}
			}
		}
		// start the power iteration from the empirical context frequencies
		for (int state=0; state<stateCount; ++state) eqDistribution[state]/=totalSum;
		convergeToEquilibrium(1e-13);
	}

	/** Power iteration: repeatedly propagates eqDistribution through the transition
	 *  probabilities until no entry changes by more than accuracy.
	 *
	 * @throws IllegalStateException if the distribution has not converged after 10000 steps.
	 */
	private void convergeToEquilibrium(double accuracy) {
		final int maxsteps = 10000;
		for (int n=0; ; ++n) {
			double[] p = new double[stateCount];
			for (int state=0; state<stateCount; ++state) {
				// private table instead of the overridable getTransitionTarget (constructor path)
				int firstTarget = transitions[state];
				for (int c=0; c<alphabetSize; ++c) {
					p[firstTarget+c]+=eqDistribution[state]*transitionProbs[state][c];
				}
			}
			// check if equilibrium is reached
			double maxdiff = 0.0;
			for (int state=0; state<stateCount; ++state) {
				double diff = Math.abs(eqDistribution[state]-p[state]);
				if (diff>maxdiff) maxdiff=diff;
			}
			eqDistribution=p;
			if (maxdiff<=accuracy) return;
			if (n>=maxsteps) throw new IllegalStateException("Markov chain does not converge!");
		}
	}

	@Override
	public int getAlphabetSize() {
		return alphabetSize;
	}

	/** Returns the string to which the given state corresponds. (Each state
	 *  corresponds to a string of length "order".) If in this state, the returned
	 *  string matches the lastly generated characters. */
	public int[] getContext(int state) {
		int[] s = new int[order];
		// extract base-alphabetSize digits, least significant (= last character) first
		int rest = state;
		for (int i=order-1; i>=0; --i) {
			s[i] = rest%alphabetSize;
			rest/=alphabetSize;
		}
		return s;
	}

	/** Returns the state reached from sourceState when generating the given character. */
	public int getTransitionTarget(int sourceState, int character) {
		return transitions[sourceState] + character;
	}

	/** Probability that, in the given state, the given character is generated. */
	@Override
	public double getProbability(int state, int character) {
		return transitionProbs[state][character];
	}

	/** Probability of generating character in sourceState and thereby reaching
	 *  targetState; 0.0 if character does not lead to targetState. */
	@Override
	public double getProbability(int sourceState, int character, int targetState) {
		if (transitions[sourceState] + character != targetState) return 0.0;
		return transitionProbs[sourceState][character];
	}

	@Override
	public int[] getTransitionTargets(int sourceState, int character) {
		return new int[] {transitions[sourceState] + character};
	}

	public int getOrder() { return order; }

	/** Given a value between 0.0 and 1.0, find the corresponding state (w.r.t.
	 *  the cumulative distribution function of the equilibrium distribution). */
	private int findState(double p) {
		if (cumulativeEqDistribution==null) {
			cumulativeEqDistribution = new double[stateCount];
			cumulativeEqDistribution[0] = eqDistribution[0];
			for (int i=1; i<stateCount; ++i) {
				cumulativeEqDistribution[i] = cumulativeEqDistribution[i-1] + eqDistribution[i];
			}
			// guard against rounding: the last prefix sum must cover every p < 1.0
			cumulativeEqDistribution[stateCount-1]=1.0;
		}
		// binary search for the smallest index i, s.t. cumulativeEqDistribution[i]>=p
		if (cumulativeEqDistribution[0]>=p) return 0;
		// invariant: cumulativeEqDistribution[l] < p <= cumulativeEqDistribution[r]
		int l = 0;
		int r = stateCount-1;
		while (true) {
			int k = (l+r+1)/2;
			if (cumulativeEqDistribution[k]<p) l = k;
			else {
				if (cumulativeEqDistribution[k-1]<p) return k;
				r = k;
			}
		}
	}

	/** Generates a random text of the given length. The first character is the last
	 *  character of a context drawn from the equilibrium distribution; each following
	 *  character is drawn from the transition probabilities of the current state.
	 *
	 * @return the generated text as character indices (empty if length is 0).
	 * @throws NegativeArraySizeException if length is negative.
	 */
	public int[] generateRandomText(int length) {
		int[] s = new int[length];
		if (length == 0) return s;
		Random random = new Random();
		// choose start
		int state = findState(random.nextDouble());
		s[0] = state%alphabetSize;
		for (int i=1; i<length; ++i) {
			double p = random.nextDouble();
			// Test the bound BEFORE indexing: rounding can leave p slightly above the sum
			// of all probabilities, so the last character acts as a catch-all instead of
			// running past the end of the row.
			int c = 0;
			while ((c<alphabetSize-1) && (p>transitionProbs[state][c])) {
				p-=transitionProbs[state][c];
				c+=1;
			}
			state = getTransitionTarget(state, c);
			s[i] = state%alphabetSize;
		}
		return s;
	}

	@Override
	public int getStateCount() { return stateCount; }

	@Override
	public double getEquilibriumProbability(int state) { return eqDistribution[state]; }

	/** Returns a copy of the equilibrium distribution. */
	@Override
	public double[] getEquilibriumDistribution() {
		return Arrays.copyOf(eqDistribution, eqDistribution.length);
	}

	/** Lists every (state, character) pair with its context and transition probability. */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		sb.append("state context character probability\n");
		for (int state=0; state<getStateCount(); ++state) {
			int[] context = getContext(state);
			for (int c=0; c<alphabetSize; ++c) {
				sb.append(String.format("%d ",state));
				for (int i : context) sb.append(i);
				sb.append(String.format(" %d %f\n",c,getProbability(state, c)));
			}
		}
		return sb.toString();
	}

	@Override
	public int order() {
		return order;
	}

}
